# -*- coding: utf-8 -*-
"""
Created on Sat Apr 20 20:43:55 2019

@author: xiang_yaobing
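Purpose: run DBSCAN to obtain candidate clusters, plot the three distribution
figures for each candidate cluster, and produce its kernel density estimate image.
Main algorithms involved:
1. DBSCAN
2. Ball tree / kNN
3. Minimum-finding algorithm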
"""

import numpy as np
import data_genrate_1 as dg
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
import os
import balltree as bt
import seaborn as sns

def example_dbscan(file_path='E:\\天文\\cluster_data\\cluster_normal\\berkeley10.r0.2.fits',
                   eps=0.3, min_samples=7):
    """Run DBSCAN on a single FITS catalogue and display the resulting plots.

    Parameters
    ----------
    file_path : str
        FITS file handed to ``dg.get_fits_to_array``. Defaults to the sample
        catalogue the original example was written for.
    eps : float
        Neighbourhood radius passed to ``DBSCAN``.
    min_samples : int
        Core-point threshold passed to ``DBSCAN``.

    Returns
    -------
    numpy.ndarray
        The rows of the loaded array whose DBSCAN label is not -1 (i.e. every
        row that was assigned to some cluster).

    Notes
    -----
    Column 4 of the loaded array is overwritten with the DBSCAN labels.
    NOTE(review): columns 0, 1, 2, 3, 5 are used as clustering features and
    columns 7 / 6 as the two CMD axes; their physical meaning is defined by
    ``dg.get_fits_to_array`` -- confirm there.
    """
    data = dg.get_fits_to_array(file_path)
    data_measure = data[:, [0, 1, 2, 3, 5]]
    loc, pm = dg.get_data_config(data)

    # Standardise each feature column (zero mean, unit variance) so that a
    # single eps is meaningful across all dimensions.
    X = StandardScaler().fit_transform(data_measure)

    db = DBSCAN(eps=eps, min_samples=min_samples).fit(X)
    labels = db.labels_

    # Label -1 marks noise, so it is not counted as a cluster.
    n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
    print('n_cluster = ', n_clusters_)

    data[:, 4] = labels
    dg.cluster_subplot(data, loc, pm)
    dg.CMD_plot(data)

    cmd_data = np.array(data[data[:, 4] != -1])

    # The previous code called only_cluster_kde_plot(cmd_data) without its
    # required save_path argument, which raised TypeError. This example only
    # displays figures, so the density map is drawn directly and not saved.
    if cmd_data.shape[0] > 0:
        cmd_data_trans = np.array([cmd_data[:, 7], -cmd_data[:, 6]])
        sns_kdeplot(cmd_data_trans)
        dg.plt.show()
    return cmd_data

def cluster_dbscan(file_path, save_path):
    """Cluster one FITS catalogue with DBSCAN and save its CMD figures.

    The scatter figure is written to ``<save_path>/scatter/<file name>.jpg``
    and the density figure to ``<save_path>/kde/<file name>.jpg``.

    Parameters
    ----------
    file_path : str
        FITS file handed to ``dg.get_fits_to_array``.
    save_path : str
        Root output directory. The ``scatter`` and ``kde`` sub-directories are
        created here if they do not exist yet.

    Notes
    -----
    Column 4 of the loaded array is overwritten with the DBSCAN labels before
    plotting. ``eps`` is taken from ``bt.get_data_split_config(X)[0][0]``.
    NOTE(review): the meaning of that value is defined in ``balltree`` --
    confirm it is always a positive distance.
    """
    data = dg.get_fits_to_array(file_path)
    data_measure = data[:, [0, 1, 2, 3, 5]]
    # Result is unused here; the call is kept in case the helper validates the
    # array -- TODO confirm it is side-effect free and drop it.
    dg.get_data_config(data)

    # Standardise each feature column (zero mean, unit variance).
    X = StandardScaler().fit_transform(data_measure)

    split_point = bt.get_data_split_config(X)[0][0]
    print('split_point=', split_point)
    db = DBSCAN(eps=split_point, min_samples=7).fit(X)
    labels = db.labels_

    # Label -1 marks noise, so it is not counted as a cluster.
    n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)

    # Accept both '\\' and '/' separators; splitting on '\\' alone left the
    # directory part in the name for forward-slash paths.
    file_name = os.path.basename(file_path.replace('\\', '/'))
    print('cluster num =',n_clusters_,' file name:',file_name)

    data[:, 4] = labels

    # Build output paths with os.path so a save_path without a trailing
    # separator still lands inside the intended sub-directories.
    scatter_dir = os.path.join(save_path, 'scatter')
    kde_dir = os.path.join(save_path, 'kde')
    for out_dir in (scatter_dir, kde_dir):
        os.makedirs(out_dir, exist_ok=True)
    cmd_plot_save_path = os.path.join(scatter_dir, file_name + '.jpg')
    cmd_kde_plot_save_path = os.path.join(kde_dir, file_name + '.jpg')

    dg.CMD_subplot_save(data, cmd_plot_save_path)
    dg.plt.show()
    only_cluster_kde_plot(data, cmd_kde_plot_save_path)
    
    
    


def deal_file_fits(file_path,save_path):
    """Run ``cluster_dbscan`` on every entry of a directory.

    Parameters
    ----------
    file_path : str
        Directory whose entries are processed one by one.
    save_path : str
        Output root; the ``scatter\\`` and ``kde\\`` sub-folders are created
        under it via ``dg.mkdir`` before any file is processed.
    """
    # Snapshot the directory listing before the output folders are created.
    entries = os.listdir(file_path)

    for sub_folder in ('scatter\\', 'kde\\'):
        dg.mkdir(save_path + sub_folder)

    for entry in entries:
        fits_file = os.path.join(file_path, entry)
        print(fits_file)
        cluster_dbscan(fits_file, save_path)
        
def get_cmd_config(cmd_data):
    """Return the axis limits of a two-row CMD array.

    Parameters
    ----------
    cmd_data : numpy.ndarray
        2-D array; row 0 and row 1 are reduced independently.
        NOTE(review): the original variable names suggest row 0 is the BP-RP
        colour and row 1 the G magnitude -- confirm with the caller.

    Returns
    -------
    list
        ``[row0_min, row0_max, row1_min, row1_max]``.
    """
    first_row = cmd_data[0, :]
    second_row = cmd_data[1, :]
    return [
        np.min(first_row),
        np.max(first_row),
        np.min(second_row),
        np.max(second_row),
    ]

def scipy_kde(cmd_data_trans, cmd_config):
    """Way 1 of describing the CMD: a Gaussian KDE rendered as an image.

    Parameters
    ----------
    cmd_data_trans : array-like
        Data handed to ``dg.gaussian_kde``.
    cmd_config : sequence
        ``[x_min, x_max, y_min, y_max]``; used both to span the 60 x 60
        evaluation grid and as the ``extent`` of the image.
    """
    density = dg.gaussian_kde(cmd_data_trans)

    # 60 samples along each axis, expanded to a full 2-D grid.
    grid_x, grid_y = np.meshgrid(
        np.linspace(cmd_config[0], cmd_config[1], 60),
        np.linspace(cmd_config[2], cmd_config[3], 60),
    )

    # gaussian_kde.evaluate takes a (2, n_points) stack of coordinates.
    grid_points = np.vstack([grid_x.ravel(), grid_y.ravel()])
    z = density.evaluate(grid_points)

    dg.plt.imshow(
        z.reshape(grid_x.shape),
        origin='lower',
        aspect='auto',
        extent=cmd_config,
        cmap='Reds',
    )

def matplotlab_plt_hist2d(cmd_data_trans):
    """Way 2 of describing the CMD: a 28-bin 2-D histogram with a colour bar.

    Parameters
    ----------
    cmd_data_trans : numpy.ndarray
        2-D array; row 0 supplies the x values and row 1 the y values.
    """
    x_values = cmd_data_trans[0, :]
    y_values = cmd_data_trans[1, :]
    dg.plt.hist2d(x_values, y_values, bins=28)
    dg.plt.colorbar().set_label('counts in bins')

def sns_kdeplot(cmd_data_trans):
    """Way 3 of describing the CMD: a filled seaborn KDE contour plot.

    A new 6 x 6 figure is opened and the density of row 0 versus row 1 is
    drawn with 60 contour levels on a reversed cubehelix colour map.

    Parameters
    ----------
    cmd_data_trans : numpy.ndarray
        2-D array; row 0 supplies the x values and row 1 the y values.
    """
    dg.plt.subplots(figsize=(6, 6))
    palette = sns.cubehelix_palette(as_cmap=True, dark=0, light=1, reverse=True)
    x_values = cmd_data_trans[0, :]
    y_values = cmd_data_trans[1, :]
    # NOTE(review): positional x/y together with `n_levels` and `shade` is the
    # legacy seaborn call signature -- verify against the installed seaborn
    # version before upgrading.
    sns.kdeplot(x_values, y_values, cmap=palette, n_levels=60, shade=True)
    
    
def only_cluster_kde_plot(data, save_path=None):
    """Draw the CMD density map of the clustered rows and optionally save it.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array whose column 4 holds cluster labels (-1 = noise). Column 7
        is plotted on the x axis and the negated column 6 on the y axis.
    save_path : str or None
        File the figure is saved to. When ``None`` the figure is only shown.

    Returns
    -------
    int
        Always 0.
    """
    cmd_data = data[data[:, 4] != -1]

    # With no clustered rows the min/max reduction and the KDE both fail, so
    # skip plotting rather than aborting a whole batch run.
    if cmd_data.shape[0] == 0:
        print('no clustered rows, kde plot skipped')
        return 0

    cmd_data_trans = np.array([cmd_data[:, 7], -cmd_data[:, 6]])
    cmd_config = get_cmd_config(cmd_data_trans)
    print(cmd_config)
    sns_kdeplot(cmd_data_trans)
    if save_path is not None:
        dg.plt.savefig(save_path)
    dg.plt.show()
    return 0

def dbscan_main(file_path='E:\\天文\\cluster_data\\正负样本\\no cluster\\',
                save_path='E:\\天文\\cluster_data\\dbscan_no_cluster_picture\\'):
    """Batch entry point: process every file in ``file_path``.

    Parameters
    ----------
    file_path : str
        Directory of input files. Defaults to the original hard-coded
        "no cluster" sample folder.
    save_path : str
        Output root directory for the generated figures. Defaults to the
        original hard-coded picture folder.
    """
    deal_file_fits(file_path, save_path)
    
'''
data = dg.get_fits_to_array('E:\\天文\\cluster_data\\cluster_normal\\NGC6819.R0.3.fits') 
#data = dg.get_fits_to_array('E:\\天文\\cluster_data\\正负样本\\no cluster\\1.fits') 
data_measure = data[:,[0,1,2,3,5]]
loc, pm = dg.get_data_config(data)
#dg.cluster_subplot(data, loc, pm)
X = StandardScaler().fit_transform(data_measure) # StandardScaler作用：去均值和方差归一化。且是针对每一个特征维度来做的，而不是针对样本。
# #############################################################################
# 调用密度聚类  DBSCAN
split_point = bt.get_data_split_config(X)[0][0]
db = DBSCAN(eps=split_point, min_samples=7).fit(X)
core_samples_mask = np.zeros_like(db.labels_, dtype=bool)  # 设置一个样本个数长度的全false向量
core_samples_mask[db.core_sample_indices_] = True #将核心样本部分设置为true
labels = db.labels_
# 获取聚类个数。（聚类结果中-1表示没有聚类为离散点）
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
print('n_cluster = ', n_clusters_)
data[:,4] = labels
#dg.cluster_subplot(data, loc, pm)
dg.CMD_plot(data)
dg.plt.show()
cmd_data1 = np.array(data[data[:, 4] != -1])
print('hello', cmd_data1.dtype)
#test = only_cluster_kde_plot(cmd_data1)
dg.plt.show()
dg.plt.figure(figsize=(6, 6))
sns.kdeplot(cmd_data1[:, 7],-cmd_data1[:, 6], bw='silverman', n_levels=6, shade=True)
dg.plt.show()
cmd1 = np.array([cmd_data1[:, 7], -cmd_data1[:, 6]])

dg.plt.subplots(figsize=(7, 6))
cmap = sns.cubehelix_palette(as_cmap=True, dark=0, light=1, reverse=True)
sns.kdeplot(cmd_data1[:, 7],-cmd_data1[:, 6], cmap=cmap, n_levels=60, shade=True, cbar=True)
#dg.plt.scatter(data[:,7], -data[:,6])
#dg.plt.scatter(cmd1[0,:], -cmd1[1,:])
'''